#!/usr/bin/python
import parser
import symbol
import token
import sys
import os

################################################################################
# Basically Obscure Assembler
################################################################################
#  8 Sep 2003 - AT&T style initial version
# 15 Sep 2003 - NASM style memory references
# 18 Sep 2003 - NASM style pseudo-ops; also comparisons, (op)= assignments
# 20 Sep 2003 - NASM style string initializers, ord('x') for character constants
# 21 Sep 2003 - NASM argument order; dx(), resx(), string literals
# 18 Oct 2003 - offset(), incbin()
################################################################################
# TODO: find how to get more informative parser errors
################################################################################

#-------------------------------------------------------------------------------
# add opcodes to this list to enable the assembler to recognize them 
#-------------------------------------------------------------------------------

# mnemonics that gcode() emits verbatim; any other called name is
# assembled as a C-style function call
opcodes = (
	"add andl cli cld cmp dec inb inc int inw lea lodsb lodsw lodsl "
	"movsb movsl movl nop orl outb outw pop popa push pusha rep ret "
	"shl shr std sti stosb stosw stosl sub test xlat xor"
	).split()

# opcodes whose collected argument list gcode() reverses before emitting
# NOTE: this name hides the reversed() builtin of later pythons; gcode()
# looks the table up under this name, so it has to stay
reversed = (
	"add andl cmp inb inw lea movl orl outb outw shl shr sub test xor"
	).split()

#-------------------------------------------------------------------------------
# the python parse tree is very deeply nested, so we pull out the values in
# which we are interested, and leave them on a token stack.
#-------------------------------------------------------------------------------

# the token stack: (type, value) pairs, top of stack at the end of the list
toks = []

def tokreset():
	""" throw away all pending tokens by rebinding the stack to a new list """
	global toks
	toks = list()

def tokpeek(pat):
	""" tokpeek(pat): is the stack non-empty with a top token of type pat? """
	return len(toks) > 0 and toks[-1][0] == pat

# other ancillary functions ----------------------------------------------------

def arg(t):
	""" render an argument token (type, value) as an at&t operand string

	'reg' values are returned unchanged, 'memref' values are formatted
	from their (disp, base, index, scale) tuple, and every other token
	becomes an immediate by prefixing its value with '$'
	"""
	(kind, val) = t
	if kind == 'reg':
		return val
	if kind != 'memref':
		return "$" + val
	# memref: use the shortest form that still carries every non-empty field
	if val[2]:
		return "%s(%s,%s,%s)" % val
	if val[1]:
		return "%s(%s)" % val[:2]
	return "%s" % val[:1]

def size(t):
	""" return the at&t size suffix for a destination token

	registers need no suffix (''); anything else gets 'l'
	"""
	(kind, _) = t
	if kind != 'reg':
		return 'l'
	return ''

def cint(s):
	""" return value of a C- or python-style numeric constant

	s is the text of a numeric token.  Recognized forms:
		0x.. / 0X..	hexadecimal
		0o.. / 0O..	octal
		0b.. / 0B..	binary
		0...		octal (C style leading zero)
		anything else	decimal
	Raises ValueError when the digits do not fit the radix and
	IndexError when s is empty.
	"""
	# plain decimal, including a lone "0"
	if s[0] != '0' or len(s) == 1:
		return int(s)
	# the second character selects the radix; upper case prefixes are
	# valid numeric tokens too, so they must not fall through to octal
	radix = s[1]
	if radix in 'xX':	return int(s[2:], 16)
	if radix in 'oO':	return int(s[2:], 8)
	if radix in 'bB':	return int(s[2:], 2)
	return int(s[1:], 8)

def glob(s):
	""" return a .globl declaration for s, a newline, then s itself

	callers append whatever should follow the name (':' or ' = value')
	"""
	declaration = ".globl %s" % (s,)
	return "%s\n%s" % (declaration, s)
def section(s):
	print ".text %d" % (s%256)

#-------------------------------------------------------------------------------
# pseudo-ops have a d/w/b suffix, for 32/16/8 bit values
#	dd, dw, db:		initialized data
#	incbin:			initialized data from external file
#	resd, resw, resb:	reserve space (optional arguments multiply)
#	align:			align to the given boundary
#
#	STRINGS can also be given as initializers, they are placed in memory
#	and padded out to the appropriate alignment
#-------------------------------------------------------------------------------

# pseudo-op name -> assembler directive printed by ginits() (a '.' is prepended)
pseudo = {
	'dd':    'long',
	'dw':    'word',
	'db':    'byte',
	'align': 'align',
}
# reserve pseudo-op name -> starting factor of the size product in gpseudo()
res = {
	'resd': 4,
	'resw': 2,
	'resb': 1,
}

def ginits(op, argl):
	""" generate initializer list """
	for t, v in argl:
		if t == token.STRING:
			align = {'dd':4,'dw':2,'db':1,'align':1}
			print ".ascii \"%s\"" % v[1:-1]
			print ".align %d,0" % align[op]
		else:
			print ".%s %s" % (pseudo[op], v)

def gpseudo(op, argl, s, c):
	""" handle nested literal pseudo-ops """
	(code, sect, flags) = c
	argl.reverse()

	if code: nested = len(toks) > 0
	else:	 nested = len(toks) > 1

	if nested: section(sect+2)	# don't use .text 0 for literals
	if pseudo.has_key(op):
		if s and op != 'align': print "%s:" % s
		ginits(op, argl)
		if s and op == 'align': print "%s:" % s
	elif res.has_key(op) and s:
		n = reduce(lambda v,t: v * cint(t[1]), argl, res[op])
		print ".comm %s,%s" % (s,n)
	elif op == 'incbin':
		if s: print "%s:" % s
		(t, v) = argl[0]
		f = open(os.path.join(dir, v[1:-1]), "r")
		while 1:
			l = f.read(16)
			if not len(l):
				break
			print ".ascii \"%s\"" % \
				repr(l)[1:-1].replace('"',r'\"')
		f.close()
	if nested: section(sect)

#-------------------------------------------------------------------------------
# code is either:
#	opcode(args,...)
#	a (op)= b		(assignment)
#	unknown text is treated as a function call ("C" conventions)
#	pseudo-op(inits,...)
#
#	STRINGs are ignored, to allow fake docstrings
#-------------------------------------------------------------------------------

# assignment operator token -> opcode stem emitted by gcode()
equals = {token.EQUAL: 'mov'}
if hasattr(token, 'PLUSEQUAL'):
	# this python tokenizes (op)= assignments, so accept those too
	equals.update({
		token.PLUSEQUAL:	'add',
		token.MINEQUAL:		'sub',
		token.AMPEREQUAL:	'and',
		token.VBAREQUAL:	'or',
		token.LEFTSHIFTEQUAL:	'shl',
		token.RIGHTSHIFTEQUAL:	'shr' })

def gcode():
	""" generate code sequences """
	if tokpeek('list'):	argl = toks.pop()[1]
	else:			argl = []
	(x,y) = toks.pop()
	if x == token.NAME:
		if y == "pass":		pass
		elif y in opcodes:
			if y in reversed: argl.reverse()
			def catargs(str, tok):
				if str: return arg(tok) + "," + str
				else:	return arg(tok)
			print y, reduce(catargs, argl, "")
		else:
			for l in argl:
				print "push", arg(l)
			print "call", y
			if len(argl):
				print "lea %d(%%esp),%%esp" % (4 * len(argl))
	elif equals.has_key(x):
		a = toks.pop()
		b = toks.pop()
		print "%s%s %s,%s" % (equals[x], size(b), arg(a), arg(b))
	elif x == token.STRING:	pass
	else:
		toks.append((x,y))
		print ".err", toks

#-------------------------------------------------------------------------------
# data is either:
#	symbol = NUMBER
#	symbol = pseudo-op(inits,...)
#	pseudo-op(inits,...)
#
#	STRINGs are ignored, to allow fake docstrings
#-------------------------------------------------------------------------------

def gdata():
	""" generate data sequences """
	# pseudo-ops are already handled by the parser,
	# so all we have to do here is generate EQU's
	if tokpeek(token.EQUAL):
		toks.pop()
		(_, y) = toks.pop()
		(_, s) = toks.pop()
		print "%s = %s" % (glob(s), y)
	elif tokpeek(token.STRING):	pass
	else:
		print ".err", toks

#-------------------------------------------------------------------------------
# conditionals can have the forms:
#	if flag:
#	if foo is flag:
#	if foo == bar:
#-------------------------------------------------------------------------------

# the last number handed out by nextlab()
labc = 0
def nextlab():
	""" bump the label counter and return its new value (1, 2, 3, ...) """
	global labc
	labc += 1
	return labc

# comparison operator token -> condition suffix that gcond() appends to 'j'
cmps = {
	token.LESS:		'b',
	token.LESSEQUAL:	'be',
	token.EQEQUAL:		'z',
	token.NOTEQUAL:		'nz',
	token.GREATEREQUAL:	'ae',
	token.GREATER:		'a',
}

def gcond():
	"""  generates predicate code, and returns the j<cond> opcode """
	(x,flag) = toks.pop()
	if flag == 'is':
		(_,flag) = toks.pop()
		gcode()
	elif cmps.has_key(x):
		a = toks.pop()
		b = toks.pop()
		print "cmpl %s,%s" % (arg(a),arg(b))
		flag = cmps[x]
	tokreset()
	return "j" + flag


#-------------------------------------------------------------------------------
# we rewrite NASM-like expressions of the form [base + index*scale + disp]
# to tokens of the form ('memref', (disp, base, index, scale))
#-------------------------------------------------------------------------------

def mref(tup):
	""" return tup coerced to a ('memref', (disp, base, index, scale)) token

	a memref passes through unchanged, a register becomes the base, and
	the value of any other token becomes the displacement
	"""
	(kind, val) = tup
	if kind == 'reg':
		val = ('', val, '', '1')
	elif kind != 'memref':
		val = (val, '', '', '1')
	return ('memref', val)

def mrewrite(tup, c):
	""" handle the interior nodes of a memref parse

	tup	parse tree node found inside [...]
	c	context tuple handed on to prtup() for the children

	Nodes of length 6 (three operands) and length 4 (two operands) are
	processed, their tokens popped again, and a single
	('memref', (disp, base, index, scale)) token is pushed in their
	place; 1 is returned.  For any other length nothing is touched and
	0 is returned, which leaves the node to the caller.

	NOTE: the parameter c is rebound to a popped token value once the
	children have been processed.
	"""
	# [base + index + off]
	if len(tup) == 6:
		# only the operands are visited; the operators at tup[2] and
		# tup[4] are never examined
		prtup(tup[1], c)
		prtup(tup[3], c)
		prtup(tup[5], c)
		# popped in reverse: a is the third operand, b the second, c the first
		(_,a) = toks.pop()
		(t,b) = toks.pop()
		(_,c) = toks.pop()
		# a plain second operand is an index with scale 1; a memref
		# (from index*scale) supplies its own index and scale
		if t != 'memref': b = (a, c, b,   '1'  )
		else:		  b = (a, c, b[2], b[3])
		toks.append(('memref', b))
		return 1
	# index * scale, [base + index], [index + off], [base + off]
	elif len(tup) == 4:	
		# left operand, right operand, then tup[2] (the operator in the
		# forms listed above) so that it ends up on top
		prtup(tup[1], c)
		prtup(tup[3], c)
		prtup(tup[2], c)
		# a is the value of tup[2], b the right operand, c the left operand
		(_,a) = toks.pop()
		(t,b) = toks.pop()
		(u,c) = toks.pop()
		# '*':          left is the index, right the scale
		# right memref: left becomes the base of the right memref
		# right reg:    right is the base, left the index (scale 1)
		# left memref:  right becomes the displacement of the left memref
		# otherwise:    right is the displacement, left the base
		if a == '*':		y = ('',   '',   c,    b   )
		elif t == 'memref':	y = (b[0], c,    b[2], b[3])
		elif t == 'reg':	y = ('',   b,     c,   '1' )
		elif u == 'memref':	y = (b,    c[1], c[2], c[3])
		else:			y = (b,    c,    '',   '1' )
		toks.append(('memref', y))
		return 1
	else:	return 0

#-------------------------------------------------------------------------------
# rewrite registers to ('reg', %name)
# if strings are not on the top level, nor in a special construct,
#     rewrite them to be a db(string) literal
#-------------------------------------------------------------------------------

# names that are rewritten to register tokens
_regnames = [
	# 32 bit
	"eax","ebx","ecx","edx","ebp","esp","esi","edi",
	# 16 bit
	"ax","bx","cx","dx","bp","sp","si","di",
	# 8 bit (cl is the shift count register for shl/shr)
	"al","ah","bl","bh","cl","ch","dl","dh",
	# segment
	"cs","ds","es","fs","gs","ss"]

def prterm(tup,c):
	""" process terminals

	tup	terminal token (type, value)
	c	context tuple (code, sect, flags)

	returns the token to push on the token stack:
	- a NAME that is a register becomes ('reg', '%' + name)
	- a STRING met while other tokens are pending, outside the
	  constructs that take strings themselves, is written out as a db
	  literal followed by a 0 byte and replaced by a NAME token
	  holding the label of that literal
	- everything else is returned unchanged
	"""
	(x,y) = tup
	(code, sect, flags) = c
	if x == token.NAME and y in _regnames:
		tup = ('reg', "%" + y)
	elif x == token.STRING and len(toks):
		# flags holds the name of the enclosing call, if there is one
		if flags not in ['db', 'dw', 'dd', 'ord', 'asm', 'incbin']:
			litsym = "l%d" % nextlab()
			# gpseudo() reverses its list: the string comes first,
			# then the terminating 0
			gpseudo('db', [(token.NUMBER, 0), tup], litsym, c)
			tup = (token.NAME, litsym)
	return tup

#-------------------------------------------------------------------------------
# collect argument lists into a python list
# ord(STRING) is rewritten to be a NUMBER character constant
# asm(STRING) is an escape to the underlying assembler
# offset(NUM) assembles an offset from the current location ('.')
# pseudo-ops are written out and, if literals, rewritten to their address
#-------------------------------------------------------------------------------

def prlist(op, c):
	""" process an argument list """
	l = []
	while 1:
		(x, y) = toks.pop()
		if x == token.COMMA:    pass
		elif x == token.LPAR:   break
		else:                   l.append((x,y))

	def sarg(l):
		toks.pop()
		(t, v) = l[0]
		return v[1:-1]

	if op == 'ord':
		toks.append((token.NUMBER, "\'%s\'" % sarg(l)))
	elif op == 'asm':
		print sarg(l)
	elif op == 'offset':
		toks.pop()
		(t, v) = l[0]
		toks.append((token.NUMBER, "(.+%s)" % v))
	elif pseudo.has_key(op) or res.has_key(op) or op == 'incbin':
		toks.pop()
		if len(toks) == 0:
			gpseudo(op, l, None, c)
		else:
			litsym = "l%d" % nextlab()
			gpseudo(op, l, litsym, c)
			toks.append((token.NAME, litsym))
	else:
		toks.append(('list', l))

#-------------------------------------------------------------------------------
# walk the python parse tree, generating assembly as we go.
# - top level productions intersperse section and control
#   flow information with the generated code
# - mid level productions recur on each of their leaves
# - low level productions rewrite the token stack: ('memref', 'list')
# - bottom level productions produce the token stack in RPN
#-------------------------------------------------------------------------------

def prtup(tup,c):
        """ process a node in the parse tree """
	(code, sect, flags) = c
	x = tup[0]

	if not x:
		return
	elif token.ISTERMINAL(x):
		#### bottom level ####
		toks.append(prterm(tup,c))
		# print "#\t", token.tok_name[x], "\t", tup[1]
		return
	# print "#.." + symbol.sym_name[x]

	def prtups(ts,c):
		for t in ts:
			prtup(t,c)

	###### top level ######
	if   x == symbol.file_input:
		section(sect)
		prtups(tup[1:-1],c)
	elif x == symbol.suite:
		prtups(tup[3:-1],c)
	elif x == symbol.funcdef:
		(_,_,(_,name),_,_,body) = tup
		section(sect+1)
		if sect+1: print "%s:" % name
		else:      print "%s:" % glob(name)
		prtup(body,(1, sect+1, None))
		print "ret"
		section(sect)
	elif x == symbol.if_stmt:
		prtup(tup[2],c)
		test = {'jnz':'jz', 'jz':'jnz',
			'jb':'jae', 'jae':'jb',
			'ja':'jbe', 'jbe':'ja'}[gcond()]
		l0 = nextlab()
		print "%s L%d" % (test, l0)
		prtup(tup[4],c)
		if len(tup) != 8:
			print "L%d:" % l0
		else:
			l1 = nextlab()
			print "jmp L%d\nL%d:" % (l1, l0)
			prtup(tup[7],c)
			print "L%d:" % l1
		tokreset()
	elif x == symbol.while_stmt:
		l0, l1 = nextlab(), nextlab()
		print "jmp L%d" % l0
		print "L%d:" % l1
		prtup(tup[4],c)
		print "L%d:" % l0
		prtup(tup[2],c)
		print "%s L%d" % (gcond(), l1)
		tokreset()
	###### mid level ######
	elif x == symbol.simple_stmt:
		for t in tup[1:]:
			(x,y) = t
			if x not in [token.SEMI, token.NEWLINE]:
				prtup(t,c)
				if len(toks):
					if code: gcode()
					else: gdata()
			tokreset()
	elif x in [symbol.arglist, symbol.subscriptlist, symbol.power]:
		prtups(tup[1:],c)
	###### low level ######
	elif x == symbol.trailer:
		kludge = toks[-1][1]
		prtups(tup[1:],(code,sect,kludge))
		if tokpeek(token.RPAR):
			toks.pop()
			prlist(kludge, c)
	elif x == symbol.atom and tup[1][0] == token.LSQB:
		prtup(tup[2],(code,sect,'memref'))
		toks.append(mref(toks.pop()))
	else:
		if flags == 'memref' and mrewrite(tup, c):
			return
	###### bottom level ######
		prtup(tup[1],c)
		if len(tup) > 2:
			prtup(tup[3],c)
			prtup(tup[2],c)
			if len(tup) > 4:
				prtup(tup[5],c)
				prtup(tup[4],c)

################################################################################

if len(sys.argv) == 1 or sys.argv[1] == '-':
	src = sys.stdin
	dir = "."
else:
	src = open(sys.argv[1], "r")
	dir = os.path.dirname(sys.argv[1])
	print ".file \"%s\"" % os.path.basename(sys.argv[1])

try:
	prtup(parser.suite(src.read()).totuple(),(0,-1,None))
except parser.ParserError, err:
	print ".err"
	sys.stderr.write("parser.suite() gave up.  Where?  Who knows.\n")
	sys.stderr.write("Is it any more informative in >2.1 Pythons?\n")
